import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
import safety_gym
import gym
import time
import  core
from utils.logx import EpochLogger
from utils.mpi_pytorch import setup_pytorch_for_mpi, sync_params, mpi_avg_grads
from utils.mpi_tools import mpi_fork, mpi_avg, proc_id, mpi_statistics_scalar, num_procs
from torch.nn.functional import softplus
torch.autograd.set_detect_anomaly(True)
import sys
import sysv_ipc
import torch.optim as optim
import torch.nn.functional as F
import ignite.handlers.param_scheduler as IGN
import pandas as pd
import matplotlib.pyplot as plt
from torch.nn.utils import clip_grad_norm_
from mpi4py import MPI
import psutil
import os
os.chdir('/home/user/safety-starter-agents/safe_rl/PPO_Lagrangian_PyTorch/data/PPO-POINT-TRAIN')
import gc
import random
from torch.distributions.normal import Normal
class Safety_NN(nn.Module):
    """Small fully connected network: two 128-unit ReLU layers and a linear head.

    The head emits ``n_class`` raw (un-normalised) values per input row.
    """

    def __init__(self, n_state, n_class):
        """Create the three linear layers.

        Args:
            n_state: Width of each input row.
            n_class: Number of values produced per row.
        """
        super().__init__()
        hidden_width = 128
        # The attribute names double as state_dict keys, so they must not change.
        self.layer1 = nn.Linear(n_state, hidden_width)
        self.layer2 = nn.Linear(hidden_width, hidden_width)
        self.layer3 = nn.Linear(hidden_width, n_class)

    def forward(self, x):
        """Return the raw head outputs for the input batch ``x``."""
        for hidden in (self.layer1, self.layer2):
            x = F.relu(hidden(x))
        return self.layer3(x)


class PPOBuffer:
    """
    Fixed-size rollout store for a PPO agent that records two signals per
    step: the reward and a cost (the ``c``-prefixed buffers).  Each signal
    gets its own GAE-Lambda advantages and discounted returns.
    """

    def __init__(self, obs_dim, act_dim, size, gamma=0.99, lam=0.97):
        """
        Allocate zero-filled float32 storage for ``size`` steps.

        ``gamma`` is the discount factor and ``lam`` the GAE lambda; both are
        shared by the reward and the cost stream.
        """
        def scalar_track():
            # One float32 value per stored step.
            return np.zeros(size, dtype=np.float32)

        self.obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros(core.combined_shape(size, act_dim), dtype=np.float32)

        self.adv_buf, self.cadv_buf = scalar_track(), scalar_track()
        self.rew_buf, self.crew_buf = scalar_track(), scalar_track()
        self.ret_buf, self.cret_buf = scalar_track(), scalar_track()
        self.val_buf, self.cval_buf = scalar_track(), scalar_track()
        self.logp_buf = scalar_track()

        self.gamma = gamma
        self.lam = lam
        self.ptr = 0
        self.path_start_idx = 0
        self.max_size = size

    # Call shape used by the training loop: buf.store(o, a, r, c, v, vc, logp)
    def store(self, obs, act, rew, crew, val,cval, logp):
        """
        Record a single environment step at the current write position and
        advance that position by one.
        """
        assert self.ptr < self.max_size     # a full buffer cannot take another step
        slot = self.ptr
        for track, value in (
            (self.obs_buf, obs),
            (self.act_buf, act),
            (self.rew_buf, rew),
            (self.crew_buf, crew),
            (self.val_buf, val),
            (self.cval_buf, cval),
            (self.logp_buf, logp),
        ):
            track[slot] = value
        self.ptr = slot + 1

    def finish_path(self, last_val=0, last_cval=0):
        """
        Close the trajectory that started at ``path_start_idx`` and ends at
        the current write position.

        For both the reward and the cost stream this fills in the GAE-Lambda
        advantages and the discounted returns-to-go of the trajectory.

        ``last_val`` / ``last_cval`` are the bootstrap values appended after
        the final step: pass 0 when the trajectory really terminated, or the
        value / cost-value estimate of the last state when it was cut off
        (episode horizon or end of epoch).
        """
        window = slice(self.path_start_idx, self.ptr)
        discount = self.gamma
        gae_discount = self.gamma * self.lam

        streams = (
            (self.rew_buf, self.val_buf, self.adv_buf, self.ret_buf, last_val),
            (self.crew_buf, self.cval_buf, self.cadv_buf, self.cret_buf, last_cval),
        )
        for signal_track, value_track, adv_track, ret_track, bootstrap in streams:
            signal = np.append(signal_track[window], bootstrap)
            values = np.append(value_track[window], bootstrap)

            # GAE-Lambda: discounted sum of the one-step TD residuals.
            td_residuals = signal[:-1] + discount * values[1:] - values[:-1]
            adv_track[window] = core.discount_cumsum(td_residuals, gae_discount)

            # Returns-to-go (bootstrap entry dropped) are the value targets.
            ret_track[window] = core.discount_cumsum(signal, discount)[:-1]

        self.path_start_idx = self.ptr

    def get(self):
        """
        Hand back everything collected during the epoch as float32 tensors
        and rewind the write pointers.

        Reward advantages are shifted and scaled with the mean/std returned
        by ``mpi_statistics_scalar``; cost advantages are only shifted by
        their mean (no rescaling).  The buffer must be completely full.
        """
        assert self.ptr == self.max_size    # only a full buffer may be drained
        self.ptr = 0
        self.path_start_idx = 0

        # Both statistics calls are kept, in this order.
        adv_mean, adv_std = mpi_statistics_scalar(self.adv_buf)
        cadv_mean, _ = mpi_statistics_scalar(self.cadv_buf)

        self.adv_buf = (self.adv_buf - adv_mean) / adv_std
        self.cadv_buf = self.cadv_buf - cadv_mean

        batch = {}
        for key, array in (
            ("obs", self.obs_buf),
            ("act", self.act_buf),
            ("ret", self.ret_buf),
            ("cret", self.cret_buf),
            ("adv", self.adv_buf),
            ("cadv", self.cadv_buf),
            ("logp", self.logp_buf),
        ):
            batch[key] = torch.as_tensor(array, dtype=torch.float32)
        return batch


def ppo(env_fn, actor_critic=core.MLPActorCritic_ppo_point_train, ac_kwargs=dict(), seed=0, 
        steps_per_epoch=4000, epochs=50, gamma=0.99, clip_ratio=0.2, pi_lr=3e-5,
        vf_lr=1e-4, train_pi_iters=80, train_v_iters=80, lam=0.97, max_ep_len=1000,
        target_kl=0.01, logger_kwargs=dict(), save_freq=200,
        render=False,
        n_NN = 1,
        n_sample_of_action = 252,
        n_class = 2,
        n_action = 2,
        file_intest_number_array = [5000000, 5000000], 
        storage_intest_number_array = [2500000, 2500000],
        pof_section = 3,
        gradient_offset = True,
        batch_training = True,
        unit_intest_batch_size_array = [5000,5000],
        intest_trajectory_step = 60,
        default_action_margin = 0.05,
        pof_checkpoint = 1,
        pof_checkpoint_load_epoch = -1,
        hazard_check_margin = 0.935,
        agent_checkpoint_mode = "error",
        ppo_checkpoint = "-1",
        cpu = -1
        ):

    setup_pytorch_for_mpi()

    intest_batch_size_array = []
    for i in range(len(unit_intest_batch_size_array)):
        intest_batch_size_array.append(unit_intest_batch_size_array[i])

    # Path setting
    pof_data_path = "/home/user/POF_data_" + str(pof_section) +"/"
    pof_checkpoint_path = pof_data_path + "checkpoints/pof-checkpoint-"+str(pof_checkpoint)+"_epoch-"+str(pof_checkpoint_load_epoch)+".pt"
    ppo_checkpoint_path = agent_checkpoint_mode+"/checkpoint/"+agent_checkpoint_mode+"_"+ppo_checkpoint+".pt"

    # Set up logger and save configuration
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())
    
    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    # Split communicator based on shared memory availability
    shared_comm = comm.Split_type(MPI.COMM_TYPE_SHARED)

    # Allocate shared memory window
    win_size = MPI.DOUBLE.Get_size()*700000000 if rank == 0 else 0
    win = MPI.Win.Allocate_shared(win_size, MPI.DOUBLE.Get_size(), comm=shared_comm)

    # Create a numpy array that points to the shared memory
    buf, _ = win.Shared_query(0)
    
    # Allocate shared memory window
    wintwo_size = MPI.INT.Get_size()*4 if rank == 0 else 0
    wintwo = MPI.Win.Allocate_shared(wintwo_size, MPI.INT.Get_size(), comm=shared_comm)

    # Create a numpy array that points to the shared memory
    buftwo, _ = wintwo.Shared_query(0)
    
    # Random seed
    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Instantiate environment
    env = env_fn()
    obs_dim = env.observation_space.shape
    n_observation = obs_dim[0]
    act_dim = env.action_space.shape
    
    device = torch.device('cpu')
    if device.type == 'cuda':
        torch.cuda.manual_seed(seed)
    intest_batch_size = intest_batch_size_array[0] + intest_batch_size_array[1]
    current_safe_intest_index_list = [np.zeros([intest_batch_size_array[0]]) for _ in range(n_NN)]
    current_unsafe_intest_index_list = [np.zeros([intest_batch_size_array[1]]) for _ in range(n_NN)]
    queue_data_list, queue_label_list, queue_counter_list = [], [], []
    env=env_fn()
    queue_data_list=[]
    queue_label_list=[]
    queue_counter_list=[]
    queue_intest_cycle_list = 0
    n_unlabeleddata_list = 0
    # IPC setting
    key_1=(proc_id()%cpu+1)*(61011+pof_section)
    key_2=(proc_id()%cpu+1)*(61022+pof_section)
    pof_shared_memory_1 = sysv_ipc.SharedMemory(key=key_1, flags=sysv_ipc.IPC_CREAT, size=820000) #####
    sem_1 = sysv_ipc.Semaphore(key_1, 0)
    pof_shared_memory_2 = sysv_ipc.SharedMemory(key=key_2, flags=sysv_ipc.IPC_CREAT, size=360000) #####
    sem_2 = sysv_ipc.Semaphore(key_2, 0)
    
    # Create actor-critic module helpmefour
    ac_ppo = actor_critic(env.observation_space, env.action_space, **ac_kwargs)
    ac_ppo.train()

    # Set up optimizers for policy and value function helpmeeight
    pi_lr = 1.5e-5
    pi_optimizer = Adam(ac_ppo.pi.parameters(), lr=pi_lr)
    vf_lr = 0.5e-4
    vf_optimizer = Adam(ac_ppo.v.parameters(), lr=vf_lr)
    cvf_optimizer = Adam(ac_ppo.vc.parameters(),lr=vf_lr)

    # SNN setting
    SNN_list = [Safety_NN(n_observation+n_action, n_class).to(device) for _ in range(n_NN)]
    warmup_local_epoch = 900000
    main_local_epoch = 17100000
    
    try:
        backup = torch.load(pof_checkpoint_path)
        if ppo_checkpoint == "-1":
            ac_ppo.load_state_dict(backup["ac_ppo"])
            ac_ppo.train()
            pi_optimizer.load_state_dict(backup["pi_optimizer"])
            vf_optimizer.load_state_dict(backup["vf_optimizer"])
            cvf_optimizer.load_state_dict(backup["cvf_optimizer"])
        # SNN setting
        optimizer_list = [optim.AdamW(SNN_list[i].parameters(), lr=1e-4, weight_decay=1e-6, amsgrad=True) for i in range(n_NN)]
        scheduler_list = [optim.lr_scheduler.LinearLR(optimizer_list[i], start_factor=1, end_factor=1e-3, total_iters=main_local_epoch) for i in range(n_NN)]
        scheduler_list_with_warmup = [IGN.create_lr_scheduler_with_warmup(scheduler_list[i], warmup_start_value=0., warmup_duration=warmup_local_epoch) for i in range(n_NN)]
        for i in range(n_NN):
            temp_name = "SNN_" + str(i)
            SNN_list[i].load_state_dict(backup[temp_name])
            temp_name = "Soptim_" + str(i)
            optimizer_list[i].load_state_dict(backup[temp_name])
            SNN_list[i].train()
            temp_name = "Ssch_" + str(i)
            scheduler_list[i].load_state_dict(backup[temp_name])
            temp_name = "Sschww_" + str(i)
            scheduler_list_with_warmup[i].load_state_dict(backup[temp_name])
        # IT setting
        storage_intest_number = backup["storage_intest_number"]
        storage_intest_state_list = backup["storage_intest_state_list"]
        storage_intest_label_list = backup["storage_intest_label_list"]
        storage_intest_accgd_list = [np.zeros([storage_intest_number, n_class]) for _ in range(n_NN)]
        safe_intest_index_list = backup["safe_intest_index_list"]
        unsafe_intest_index_list = backup["unsafe_intest_index_list"]
        current_intest_state_list = backup["current_intest_state_list"]
        current_intest_label_list = backup["current_intest_label_list"]
        current_intest_size = backup["current_intest_size"]
        # ETC setting
        backup_epoch = backup["epoch"]+1
        context = backup["context"]
        print("    LOAD CHECKPOINT %s" %(pof_checkpoint_path))
    except Exception as e:
        print(e)
        if ppo_checkpoint != "-1":
            backup = torch.load(ppo_checkpoint_path)
            backup_dict = backup.state_dict()
            selected_net = ['pi', 'v', 'vc']
            selecte_dict = {k: v for k, v in backup_dict.items() if any(net_name in k for net_name in selected_net)}
            ac_ppo.load_state_dict(selecte_dict, strict=False)
        print("    LOAD PPO %s" %(ppo_checkpoint_path))
        log_std = -2 * np.ones(act_dim[0], dtype=np.float32)
        ac_ppo.pi.log_std.data = torch.as_tensor(log_std)
        # SNN setting
        optimizer_list = [optim.AdamW(SNN_list[i].parameters(), lr=1e-4, weight_decay=1e-6, amsgrad=True) for i in range(n_NN)]
        scheduler_list = [optim.lr_scheduler.LinearLR(optimizer_list[i], start_factor=1, end_factor=1e-3, total_iters=main_local_epoch) for i in range(n_NN)]
        scheduler_list_with_warmup = [IGN.create_lr_scheduler_with_warmup(scheduler_list[i], warmup_start_value=0., warmup_duration=warmup_local_epoch) for i in range(n_NN)]
        for i in range(n_NN): SNN_list[i].train()
        # IT setting
        storage_intest_number = storage_intest_number_array[0] + storage_intest_number_array[1]
        storage_intest_state_list = []
        storage_intest_label_list =[]
        storage_intest_accgd_list = [np.zeros([storage_intest_number, n_class]) for _ in range(n_NN)]
        safe_intest_index_list = [list() for _ in range(n_NN)]
        unsafe_intest_index_list = [list() for _ in range(n_NN)]
        current_intest_state_list = [np.zeros([intest_batch_size, n_observation+n_action]) for _ in range(n_NN)]
        current_intest_label_list = [np.zeros([intest_batch_size, 1]) for _ in range(n_NN)]
        current_intest_size = intest_batch_size
        # ETC setting
        backup_epoch = 0 
        context = [storage_intest_number_array[0], storage_intest_number_array[1]]    
        intest_dataset_path = agent_checkpoint_mode+"/intest/intest_obs.csv"
        obss_df = pd.read_csv(intest_dataset_path, header=None)
        obss_np = obss_df.to_numpy()
        assert(n_action == 2)
        for nn_order in range(n_NN):
            if proc_id() == 0: storage_intest_state_list.append([])
            storage_intest_label_list.append([])
            for intest_order in range(storage_intest_number_array[0]):
                if proc_id() == 0: storage_intest_state_list[nn_order].append(obss_np[intest_order])
                storage_intest_label_list[nn_order].append([0.0])
                safe_intest_index_list[nn_order].append(intest_order)
            for intest_order in range(storage_intest_number_array[1]):
                if proc_id() == 0: storage_intest_state_list[nn_order].append(obss_np[file_intest_number_array[0]+intest_order])
                storage_intest_label_list[nn_order].append([1.0])
                unsafe_intest_index_list[nn_order].append(storage_intest_number_array[0]+intest_order)
            if proc_id() == 0: storage_intest_state_list=np.array(storage_intest_state_list)
            storage_intest_label_list=np.array(storage_intest_label_list)
            if proc_id() == 0:
                print(storage_intest_state_list.shape)
                assert storage_intest_state_list[nn_order].shape[0] == storage_intest_number
            assert storage_intest_label_list[nn_order].shape[0] == storage_intest_number
            assert len(safe_intest_index_list[nn_order]) == storage_intest_number_array[0]
            assert len(unsafe_intest_index_list[nn_order]) == storage_intest_number_array[1]  
            print("    CANNOT LOAD CHECKPOINT")
        del obss_np
        del obss_df
    tmp=storage_intest_state_list
    del storage_intest_state_list
    storage_intest_state_list=np.ndarray(buffer=buf, dtype='d', shape=(1,storage_intest_number_array[0]+storage_intest_number_array[1],62))
    print("before_copy")
    if proc_id()==0:
        win.Lock(0, lock_type=MPI.LOCK_EXCLUSIVE)
        for i in range(storage_intest_number_array[0]+storage_intest_number_array[1]):
           for j in range(62):
               storage_intest_state_list[0,i,j]=tmp[0,i,j]
        print("after_copy")
        win.Unlock(0)
    win.Fence()
    del tmp
    gc.collect()
    queue_intest_index_and_number = np.ndarray(buffer=buftwo, dtype='i', shape=(4)) # [index0, index1, number0, number1]
    if proc_id()==0:
        wintwo.Lock(0, lock_type=MPI.LOCK_EXCLUSIVE)
        for i in range(4): queue_intest_index_and_number[i] = 0
        wintwo.Unlock(0)
    wintwo.Fence()
    
    for i in range(n_NN): 
        sync_params(SNN_list[i])


    def queue_storage_intest(selected_action, safety_label, state_reshape, step, goal_checking):  
        nonlocal storage_intest_state_list
        nonlocal queue_intest_cycle_list, n_unlabeleddata_list, queue_data_list, queue_label_list, queue_counter_list
        nonlocal queue_intest_index_and_number, intest_rejected_count

        queue_intest_cycle_list+=1
        if queue_intest_cycle_list >= 10*10000000//(storage_intest_number_array[0]+storage_intest_number_array[1]):
            queue_intest_cycle_list = 0
            if safety_label == 0:
                queue_label_list.append(safety_label)
                queue_data_list.append(np.concatenate((state_reshape, np.array([selected_action])), axis=1))
                queue_counter_list.append(0)
                n_unlabeleddata_list+=1
            else: 
                intest_rejected_count += 1

        if goal_checking:
            while n_unlabeleddata_list > 0:
                pop_label = queue_label_list.pop(0)
                assert pop_label == 0
                pop_state = queue_data_list.pop(0)
                queue_counter_list.pop(0)
                n_unlabeleddata_list -= 1
                queue_intest_index_tmp = queue_intest_index_and_number[pop_label]
                storage_intest_state_list[0][safe_intest_index_list[0][queue_intest_index_tmp]] = pop_state
                queue_intest_index_and_number[pop_label] = (queue_intest_index_tmp+1) % storage_intest_number_array[pop_label]
                queue_intest_index_and_number[2+pop_label] += 1
        elif safety_label == 1:
            while n_unlabeleddata_list > 0:
                pop_label = queue_label_list.pop(0)
                assert pop_label == 0
                pop_state = queue_data_list.pop(0)
                queue_counter_list.pop(0)
                n_unlabeleddata_list -= 1
                queue_intest_index_tmp = queue_intest_index_and_number[safety_label]
                storage_intest_state_list[0][unsafe_intest_index_list[0][queue_intest_index_tmp]] = pop_state
                queue_intest_index_and_number[safety_label] = (queue_intest_index_tmp+1) % storage_intest_number_array[safety_label]
                queue_intest_index_and_number[2+safety_label] += 1
        elif n_unlabeleddata_list == 0: 
            assert n_unlabeleddata_list >= 0
            pass
        else:
            for i in range(n_unlabeleddata_list): queue_counter_list[i] += 1
            if queue_counter_list[0] == intest_trajectory_step:
                pop_label = queue_label_list.pop(0)
                assert pop_label == 0
                pop_state = queue_data_list.pop(0)
                queue_counter_list.pop(0)
                n_unlabeleddata_list -= 1
                queue_intest_index_tmp = queue_intest_index_and_number[pop_label]
                storage_intest_state_list[0][safe_intest_index_list[0][queue_intest_index_tmp]] = pop_state
                queue_intest_index_and_number[pop_label] = (queue_intest_index_tmp+1) % storage_intest_number_array[pop_label]
                queue_intest_index_and_number[2+pop_label] += 1
    
    
    def internal_test_case_sampling(): 
        nonlocal storage_intest_state_list, storage_intest_label_list, safe_intest_index_list, unsafe_intest_index_list, current_intest_state_list, current_intest_label_list
        nonlocal current_safe_intest_index_list, current_unsafe_intest_index_list
        for i in range(n_NN):
            current_safe_intest_index_list[i] = np.random.randint(0, high=storage_intest_number_array[0], size=intest_batch_size_array[0], dtype=int)
            current_unsafe_intest_index_list[i] = np.random.randint(storage_intest_number_array[0], high=storage_intest_number_array[0]+storage_intest_number_array[1], size=intest_batch_size_array[1], dtype=int)
            current_intest_label_list[i][:intest_batch_size_array[0]] = storage_intest_label_list[i][current_safe_intest_index_list[i]]
            current_intest_state_list[i][:intest_batch_size_array[0]] = storage_intest_state_list[i][current_safe_intest_index_list[i]]
            current_intest_label_list[i][intest_batch_size_array[0]:] = storage_intest_label_list[i][current_unsafe_intest_index_list[i]]
            current_intest_state_list[i][intest_batch_size_array[0]:] = storage_intest_state_list[i][current_unsafe_intest_index_list[i]]
        return [current_intest_label_list[i].shape[0] for _ in range(n_NN)] 


    def POF_batch_making(state, a_candidate_dense, mode):
        """Evaluate every safety network on candidate actions plus internal test cases.

        A batch is built per network by stacking ``n_sample_of_action`` rows
        of ``state`` concatenated with one candidate action each, followed by
        that network's internal test-case rows.  Columns 60:62 are clipped to
        [-1, 1] and column 60 is then snapped to +1.0 (if > 0) or -1.0 before
        the forward pass.

        Args:
            state: Observation; flattened to a single row.
            a_candidate_dense: Candidate actions, reshaped to
                ``(n_sample_of_action, -1)``.
            mode: ``"it_batch"`` uses the current mini-batch of test cases,
                ``"it_total"`` uses the whole storage.

        Returns:
            ``(result_batch_list, result_sample_array, result_intest_list)``:
            the raw network output tensors (one per network), the outputs for
            the candidate rows of all networks stacked along axis 0, and per
            network the test-case labels concatenated with their outputs.
            Returns ``None`` for any other ``mode``.
        """
        if mode == "it_batch":
            intest_state_list = current_intest_state_list
            intest_label_list = current_intest_label_list
        elif mode == "it_total":
            intest_state_list = storage_intest_state_list
            intest_label_list = storage_intest_label_list
        else:
            return

        # The candidate rows do not depend on the network, so build them once.
        state_action_rows = np.concatenate(
            (
                np.repeat(state.reshape(1,-1), repeats=n_sample_of_action, axis=0),
                a_candidate_dense.reshape(n_sample_of_action,-1),
            ),
            axis=1,
        )

        result_batch_list = []
        result_sample_array = np.empty([0,n_class])
        result_intest_list = []
        for nn_order in range(n_NN):
            batch_temp = np.concatenate((state_action_rows, intest_state_list[nn_order]), axis=0)
            batch_temp[:,60:62] = np.clip(batch_temp[:,60:62],-1,1)
            # Vectorised replacement of the former per-row Python loop.
            batch_temp[:,60] = np.where(batch_temp[:,60] > 0, 1.0, -1.0)
            network_output = SNN_list[nn_order](torch.FloatTensor(batch_temp).to(device))
            result_batch_list.append(network_output)
            result_sample_tmp, result_intest_tmp = np.split(network_output.detach().cpu().numpy(), [n_sample_of_action], axis=0)
            result_intest_list.append(np.concatenate((intest_label_list[nn_order], result_intest_tmp), axis=1))
            result_sample_array = np.append(result_sample_array, result_sample_tmp, axis=0)
        return result_batch_list, result_sample_array, result_intest_list


    def POF_update_loading(status, shared_string, intest_size_list):
        """Parse a newline-separated reply string into an action and gradients.

        Args:
            status: ``"train"`` or ``"val"``; anything else is ignored.
            shared_string: The reply text, or ``None`` when there is nothing
                to parse.  Its first line is the integer action.
            intest_size_list: Accepted for interface compatibility; not used.

        Returns:
            * ``None`` when ``shared_string`` is ``None`` or ``status`` is
              neither ``"train"`` nor ``"val"``.
            * ``"val"``: the integer action from the first line.
            * ``"train"``: ``(pof_action, grad_array_list)`` where each array
              has shape ``(n_sample_of_action + unit batch size, n_class)``
              and is parsed from the lines after the first.  When
              ``gradient_offset`` is set, the rows after the first
              ``n_sample_of_action`` have their column mean subtracted.
        """
        if shared_string is None: # txt version
            return None
        if status == "val": # string version - val
            return int(shared_string.split("\n")[0])
        if status != "train": # string version - error
            return None

        # string version - train
        fields = shared_string.split("\n")
        pof_action = int(fields[0])
        unitsize_tmp = unit_intest_batch_size_array[0] + unit_intest_batch_size_array[1]
        size = n_sample_of_action + unitsize_tmp
        parsed = np.array(fields[1:1+size*n_class]).astype(float).reshape((size,n_class))
        # NOTE(review): every network receives a copy of the same rows (the
        # slice does not depend on the network index) - confirm this is the
        # intended protocol when n_NN > 1.
        grad_array_list = [parsed.copy() for _ in range(n_NN)]
        if gradient_offset:
            for grad_array in grad_array_list:
                intest_rows = grad_array[n_sample_of_action:n_sample_of_action + unitsize_tmp]
                # Centre the test-case gradients per class (broadcast over rows).
                intest_rows -= intest_rows.mean(axis=0)
        return pof_action, grad_array_list


    def one_step_for_batch_training(pof_gradient_list, network_output_list):
        """Back-propagate supplied output gradients and accumulate test-case gradients.

        For each network the loss ``sum(grad * output) + mean(output**2)`` is
        back-propagated (graph retained).  The resulting gradient with respect
        to the network output is then added, row by row, onto
        ``storage_intest_accgd_list`` at the storage indices of the test cases
        that formed the current mini-batch.

        Args:
            pof_gradient_list: One NumPy array per network holding the
                externally computed gradient for every output row.
            network_output_list: One output tensor per network.  Its ``.grad``
                is read after ``backward``.
                # NOTE(review): these look like non-leaf tensors, so the
                # caller presumably invoked ``retain_grad()`` - confirm.

        If any supplied gradient contains NaN, a message is printed and the
        function returns immediately; networks handled before that point have
        already been back-propagated.
        """
        n_safe_rows = intest_batch_size_array[0]
        n_unsafe_rows = intest_batch_size_array[1]
        # Output rows are laid out as: candidate samples, safe cases, unsafe cases.
        safe_start = n_sample_of_action
        unsafe_start = safe_start + n_safe_rows
        unsafe_stop = unsafe_start + n_unsafe_rows

        # sample gradient accumulation && internal test case gradient accumulation
        for i in range(n_NN):
            temp_grad_array = torch.from_numpy(pof_gradient_list[i]).to(device)
            if torch.isnan(temp_grad_array).any():
                print("NAN in GRAD tensor")
                return
            pof_loss = torch.sum(torch.mul(temp_grad_array, network_output_list[i]))
            regularization_term = (torch.sum(network_output_list[i]**2))/(torch.numel(network_output_list[i]))
            total_loss = pof_loss + regularization_term
            total_loss.backward(retain_graph=True)
            network_grad_list_np = network_output_list[i].grad.cpu().detach().numpy()
            # np.add.at is unbuffered, so indices that were sampled more than
            # once are accumulated once per occurrence, like the former loop.
            np.add.at(
                storage_intest_accgd_list[i],
                current_safe_intest_index_list[i][:n_safe_rows],
                network_grad_list_np[safe_start:unsafe_start],
            )
            np.add.at(
                storage_intest_accgd_list[i],
                current_unsafe_intest_index_list[i][:n_unsafe_rows],
                network_grad_list_np[unsafe_start:unsafe_stop],
            )

                   
    def update_pof():
        assert batch_training
        nonlocal SNN_list, optimizer_list
        for i in range(n_NN):
            mpi_avg_grads(SNN_list[i])
            scheduler_list_with_warmup[i](None)
            optimizer_list[i].step()
            optimizer_list[i].zero_grad()


    def POF_output_saving(writing_mode, status, x_vel_list, epoch, step, mu, std, result_sample_array, result_intest_list,hazard_check) -> str:
        """Serialise the network outputs of one step, to a text file or to a string.

        The payload is a sequence of lines: status code (0 train, 1 val,
        -1 otherwise), number of rows of ``result_intest_list[0]``, a default
        action code, the policy mean/std, ``n_class``, the candidate-sample
        outputs, ``context`` and finally the test-case results.

        The default action code is ``-10020`` / ``-10102`` when
        ``hazard_check`` is set and ``x_vel_list`` is negative / positive;
        otherwise ``-102`` / ``-20`` when ``x_vel_list`` is below / above
        ``-/+default_action_margin``; otherwise ``-61``.

        Args:
            writing_mode: ``"txt"`` writes
                ``POF_OUTPUT_shm_<epoch>_<step>_<rank>.txt`` under
                ``pof_data_path`` (one matrix row per line, values followed by
                a space, ``mu``/``std`` written via ``str``).  ``"shm"``
                returns one value per line, with the first two entries of
                ``mu`` clipped to [-4, 4] and the first two entries of ``std``.
            status: ``"train"``, ``"val"`` or anything else.
            x_vel_list: Scalar compared against zero and the action margin.
            epoch, step: Only used in the ``"txt"`` file name.
            mu, std: Policy mean and standard deviation.
            result_sample_array: 2-D array of candidate-sample outputs.
            result_intest_list: List of 2-D arrays (one per network).
            hazard_check: Selects the hazard variants of the action code.

        Returns:
            ``"txt complete"``, the payload string (``"shm"``), or
            ``"nothing complete"`` for any other ``writing_mode``.
        """
        if writing_mode not in ("txt", "shm"):
            return "nothing complete"

        if status == "train": n_status = 0
        elif status == "val": n_status = 1
        else: n_status = -1

        intest_rows = str(result_intest_list[0].shape[0])

        # Shared by both output formats (was duplicated in each branch).
        if hazard_check and x_vel_list < 0: action_code = "-10020"
        elif hazard_check and x_vel_list > 0: action_code = "-10102"
        elif x_vel_list < -default_action_margin: action_code = "-102"
        elif x_vel_list > default_action_margin: action_code = "-20"
        else: action_code = "-61"

        if writing_mode == "txt":
            def as_row(values):
                # Every value is followed by a single space, as before.
                return "".join(str(value) + " " for value in values)

            lines = [str(n_status), intest_rows, action_code, str(mu), str(std), str(n_class)]
            lines.extend(as_row(row) for row in result_sample_array)
            lines.append(as_row(context))
            for intest_result in result_intest_list:
                lines.extend(as_row(row) for row in intest_result)

            pof_output_path = pof_data_path + "POF_OUTPUT_shm_"+str(epoch)+"_"+str(step)+'_'+str(proc_id())+".txt"
            # Context manager closes the file even if the write fails.
            with open(pof_output_path, "w") as file:
                file.write("\n".join(lines) + "\n")
            return "txt complete"

        # writing_mode == "shm"
        lines = [
            str(n_status),
            intest_rows,
            action_code,
            str(max(min(mu[0].detach().item(), 4.0), -4.0)),
            str(max(min(mu[1].detach().item(), 4.0), -4.0)),
            str(std[0].detach().item()),
            str(std[1].detach().item()),
            str(n_class),
        ]
        lines.extend(map(str, result_sample_array.flatten()))
        lines.extend(map(str, context))
        for arr in result_intest_list:
            lines.extend(map(str, arr.flatten()))
        shared_string = "\n".join(lines) + "\n"
        return shared_string


    def POF_storage_it_output_saving(epoch, step):
        """Evaluate every network on the whole test-case storage, then plot and dump it.

        The storage is pushed through each network in chunks of at most
        200000 rows.  Before the forward pass columns 60:62 of a chunk copy
        are clipped to [-1, 1] and column 60 is snapped to +1.0 / -1.0.  The
        per-row result is ``[label, output_0, ..., output_{n_class-1}]``.

        Side effects:
            * Every process saves a scatter plot (one subplot per network and
              class) to
              ``<pof_data_path>/data_format/STORAGE_IT_OUTPUT_PLOT_<epoch>_<step>_<rank>.png``.
            * The process with ``proc_id() == 0`` additionally writes
              ``POF_OUTPUT_<epoch>_<step>_<rank>.txt``: the storage size on
              the first line, then one result row per line.

        Args:
            epoch, step: Only used in the output file names.
        """
        chunk_rows = min(storage_intest_number, 200000)
        # Ceiling division so a final partial chunk is evaluated too; the
        # former floor division left those rows as zeros.
        n_chunks = -(-storage_intest_number // chunk_rows) if chunk_rows else 0
        print("Print tmp: ", n_chunks)

        total_intest_output_list = []
        for i in range(n_NN):
            total_intest_output = np.zeros([storage_intest_number, n_class+1])
            for j in range(n_chunks):
                lo = chunk_rows * j
                hi = min(lo + chunk_rows, storage_intest_number)
                batch_temp = np.copy(storage_intest_state_list[i][lo:hi,:])
                batch_temp[:,60:62] = np.clip(batch_temp[:,60:62],-1,1)
                batch_temp[:,60] = np.where(batch_temp[:,60] > 0, 1.0, -1.0)
                # Pure inference: no autograd graph is needed for these rows.
                with torch.no_grad():
                    outputs = SNN_list[i](torch.FloatTensor(batch_temp).to(device)).detach().cpu().numpy()
                total_intest_output[lo:hi,:] = np.concatenate((storage_intest_label_list[i][lo:hi,:], outputs), axis=1)
            total_intest_output_list.append(total_intest_output)

        n_color_array = np.array(["blue", "red"])
        colormap = [0]*storage_intest_number_array[0] + [1]*storage_intest_number_array[1]
        # squeeze=False always yields a 2-D axes grid, so indexing by
        # (network, class) works for any n_NN / n_class, not only n_NN == 1.
        figone, axsone = plt.subplots(n_NN, n_class, figsize=(15, 9), squeeze=False)
        for i in range(n_NN):
            for j in range(n_class):
                axsone[i][j].scatter(range(0,storage_intest_number), total_intest_output_list[i][0:storage_intest_number,j+1], color=n_color_array[colormap], s=1.0)
                axsone[i][j].set_title(str(i+1)+'_NN, '+str(j+1)+'_class')
        figone.suptitle('STORATE IT OUTPUT PLOT', fontsize=16)
        plt.savefig(pof_data_path + '/data_format/STORAGE_IT_OUTPUT_PLOT_'+str(epoch)+'_'+str(step)+'_'+str(proc_id())+'.png')
        # close('all') already releases the figure; a following plt.clf()
        # would implicitly create a new empty one.
        plt.close('all')

        if proc_id()==0:
            pof_output_path = pof_data_path + "POF_OUTPUT_"+str(epoch)+"_"+str(step)+'_'+str(proc_id())+".txt"
            with open(pof_output_path, "w") as file:
                file.write(str(storage_intest_number))
                file.write("\n")
                for network_rows in total_intest_output_list:
                    file.writelines(
                        "".join(str(value) + " " for value in row) + "\n"
                        for row in network_rows
                    )


    def POF_storage_it_update_saving(epoch, step):
        """Plot and dump the gradients accumulated for every stored test case.

        Per network the rows are ``[label, accumulated gradient per class]``,
        built from ``storage_intest_label_list`` and
        ``storage_intest_accgd_list``.

        Side effects:
            * Every process saves a scatter plot (one subplot per network and
              class) to
              ``<pof_data_path>/data_format/STORAGE_IT_GRADIENT_PLOT_<epoch>_<step>_<rank>.png``.
            * The process with ``proc_id() == 0`` writes
              ``POF_UPDATE_<epoch>_<step>_<rank>.txt`` (one row per line, a
              blank line after each network) and then zeroes its accumulators.

        Args:
            epoch, step: Only used in the output file names.
        """
        total_intest_grad_list = [
            np.concatenate((storage_intest_label_list[i], storage_intest_accgd_list[i]), axis=1)
            for i in range(n_NN)
        ]

        # total gradient img
        n_color_array = np.array(["blue", "red"])
        colormap = [0]*storage_intest_number_array[0] + [1]*storage_intest_number_array[1]
        # squeeze=False always yields a 2-D axes grid, so indexing by
        # (network, class) works for any n_NN / n_class, not only n_NN == 1.
        figone, axsone = plt.subplots(n_NN, n_class, figsize=(15, 9), squeeze=False)
        for i in range(n_NN):
            for j in range(n_class):
                axsone[i][j].scatter(range(0,storage_intest_number), total_intest_grad_list[i][0:storage_intest_number,j+1], color=n_color_array[colormap], s=1.0)
                axsone[i][j].set_title(str(i+1)+'_NN, '+str(j+1)+'_class')
        figone.suptitle('STORATE IT GRADIENT PLOT', fontsize=16)
        plt.savefig(pof_data_path + '/data_format/STORAGE_IT_GRADIENT_PLOT_'+str(epoch)+'_'+str(step)+'_'+str(proc_id())+'.png')
        # close('all') already releases the figure; a following plt.clf()
        # would implicitly create a new empty one.
        plt.close('all')

        # total gradient txt for plot
        if proc_id()==0:
            pof_udpate_path = pof_data_path + "POF_UPDATE_"+str(epoch)+"_"+str(step)+'_'+str(proc_id())+".txt"
            with open(pof_udpate_path, "w") as file:
                for i in range(n_NN):
                    file.writelines(
                        "".join(str(value.item()) + " " for value in row) + "\n"
                        for row in total_intest_grad_list[i]
                    )
                    file.write("\n")
                    # NOTE(review): the accumulators are reset on this process
                    # only; other processes keep accumulating between calls -
                    # confirm that is intended.
                    storage_intest_accgd_list[i].fill(0)


    tmp_count = 0
    def pof_batch_training(epoch, step, state, x_vel, a_candidate_dense, mu, std, hazard_check):
        assert batch_training
        nonlocal pof_shared_memory_1, pof_shared_memory_2
        nonlocal tmp_count
        intest_size_list = internal_test_case_sampling()
        result_batch_networkoutput_list, result_sample, result_intest_list = POF_batch_making(state, a_candidate_dense, "it_batch")
        for index in range(n_NN):
            result_batch_networkoutput_list[index].retain_grad()
            if True in result_batch_networkoutput_list[index].isnan():
                print("NAN in OUTPUT tensor")
                return
        shared_string_1 = POF_output_saving("shm", "train", x_vel, 0, 0, mu,std, result_sample, result_intest_list,hazard_check)
        pof_shared_memory_1.write(shared_string_1.encode())
        sem_1.V() ## S1-1 & S2-0
        sem_2.P() ## S1-0 & S2-0
        shared_string_2 = pof_shared_memory_2.read().decode()
        selected_action, grad_array_list = POF_update_loading("train", shared_string_2, intest_size_list)
        one_step_for_batch_training(grad_array_list, result_batch_networkoutput_list)
        tmp_count += 1
        update_pof()
        if tmp_count == 100000:
            if proc_id()==0: 
                POF_output_saving("txt", "train", x_vel, epoch, step, mu, std, result_sample, result_intest_list,hazard_check)
                POF_storage_it_output_saving(epoch, step)
                POF_storage_it_update_saving(epoch, step)            
                tmp_count = 0
            gc.collect()
        return selected_action

    # Sync params across processes
    sync_params(ac_ppo)

    # Count variables
    # Only the pi and v sub-modules are counted; ac_ppo.vc is not included
    # in the reported numbers.
    var_counts = tuple(core.count_vars(module) for module in [ac_ppo.pi, ac_ppo.v])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n'%var_counts)

    # Set up experience buffer
    # Each process gets an equal share of the epoch; int() truncates, so any
    # remainder of steps_per_epoch / num_procs() is dropped.
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)
    
    # Set up function for computing PPO policy loss
    def compute_loss_pi(data):
        """Compute the clipped PPO policy loss plus diagnostics.

        Args:
            data: Batch dict; the keys 'obs', 'act', 'adv', 'cadv', 'logp'
                and 'cur_cost' are read.

        Returns:
            A tuple ``(loss_pi, cost_deviation, pi_info)``:
            ``loss_pi`` is the negated mean of the clipped reward surrogate,
            ``cost_deviation`` is ``data['cur_cost']`` minus the hard-coded
            cost limit, and ``pi_info`` is a dict with the float entries
            'kl' (approximate KL), 'ent' (mean entropy) and 'cf' (fraction
            of ratios outside the clip range).
        """
        observations = data['obs']
        actions = data['act']
        reward_adv = data['adv']
        cost_adv = data['cadv']
        old_logp = data['logp']
        episode_cost = data['cur_cost']

        cost_limit = 5000000
        lower_bound = 1-clip_ratio
        upper_bound = 1+clip_ratio

        # Probability ratio between the current and the data-collecting policy.
        dist, new_logp = ac_ppo.pi(observations, actions)
        ratio = torch.exp(new_logp - old_logp)

        # Reward surrogate: pessimistic minimum of clipped and unclipped terms.
        clipped_term = torch.clamp(ratio, lower_bound, upper_bound) * reward_adv
        unclipped_term = ratio * reward_adv
        reward_surrogate = torch.min(unclipped_term, clipped_term).mean()

        # Unclipped cost surrogate. It is evaluated but does not enter the
        # returned loss.
        cost_surrogate = (ratio * cost_adv).mean()

        loss_pi = -reward_surrogate
        cost_deviation = episode_cost - cost_limit

        # Diagnostics, detached to plain Python floats.
        outside_clip = (ratio > upper_bound) | (ratio < lower_bound)
        pi_info = {
            'kl': (old_logp - new_logp).mean().item(),
            'ent': dist.entropy().mean().item(),
            'cf': outside_clip.to(torch.float32).mean().item(),
        }

        return loss_pi, cost_deviation, pi_info

    # Set up function for computing value loss
    def compute_loss_v(data):
        """Return the mean-squared errors of the reward and cost value heads.

        Args:
            data: Batch dict; the keys 'obs', 'ret' and 'cret' are read.

        Returns:
            A tuple ``(loss_v, loss_vc)``: the mean squared difference between
            ``ac_ppo.v(obs)`` and 'ret', and between ``ac_ppo.vc(obs)`` and
            'cret'.
        """
        observations = data['obs']
        reward_returns = data['ret']
        cost_returns = data['cret']

        reward_error = ac_ppo.v(observations) - reward_returns
        reward_value_loss = (reward_error ** 2).mean()

        cost_error = ac_ppo.vc(observations) - cost_returns
        cost_value_loss = (cost_error ** 2).mean()

        return reward_value_loss, cost_value_loss

    def update():
        """Run one PPO update on the data currently held in the buffer.

        The policy is trained for up to 80 gradient steps, stopping early once
        the MPI-averaged approximate KL exceeds ``1.2 * target_kl``. The reward
        and cost value functions are then trained for 80 steps each, every
        step averaging gradients across MPI processes. Loss values before the
        update and their change are stored in the logger.
        """
        mean_episode_cost = logger.get_stats('EpCost')[0]
        data = buf.get()
        data['cur_cost'] = mean_episode_cost

        # Baseline losses, taken before any parameter changes.
        initial_pi_loss, _, pi_info_old = compute_loss_pi(data)
        pi_l_old = initial_pi_loss.item()
        initial_v_loss, initial_cv_loss = compute_loss_v(data)
        v_l_old = initial_v_loss.item()
        cv_l_old = initial_cv_loss.item()

        # Policy optimisation with KL-based early stopping
        train_pi_iters = 80
        for pi_step in range(train_pi_iters):
            pi_optimizer.zero_grad()
            loss_pi, _, pi_info = compute_loss_pi(data)
            # The KL is averaged over processes so every rank takes the same
            # branch here.
            if mpi_avg(pi_info['kl']) > 1.2 * target_kl:
                logger.log('Early stopping at step %d due to reaching max kl.'%pi_step)
                break

            loss_pi.backward()
            mpi_avg_grads(ac_ppo.pi)    # average grads across MPI processes
            pi_optimizer.step()

        # Index of the last policy iteration entered (the one that broke out,
        # or the final one).
        logger.store(StopIter=pi_step)

        # Value function learning: reward critic first, then cost critic
        train_v_iters = 80
        for _ in range(train_v_iters):
            loss_v, loss_vc = compute_loss_v(data)

            vf_optimizer.zero_grad()
            loss_v.backward()
            mpi_avg_grads(ac_ppo.v)   # average grads across MPI processes
            vf_optimizer.step()

            cvf_optimizer.zero_grad()
            loss_vc.backward()
            mpi_avg_grads(ac_ppo.vc)
            cvf_optimizer.step()

        # Log changes from update. KL and clip fraction come from the last
        # policy evaluation, entropy from the pre-update evaluation.
        logger.store(
            LossPi=pi_l_old,
            LossV=v_l_old,
            KL=pi_info['kl'],
            Entropy=pi_info_old['ent'],
            ClipFrac=pi_info['cf'],
            DeltaLossPi=(loss_pi.item() - pi_l_old),
            DeltaLossV=(loss_v.item() - v_l_old),
        )
        

    # ------------------------------------------------------------------
    # Main loop: for every epoch, collect local_steps_per_epoch environment
    # steps (calling pof_batch_training on each one to obtain the action),
    # optionally write a checkpoint, run the PPO update and dump the log row.
    # ------------------------------------------------------------------
    print("    START NEW SESSION!")
    start_time = time.time()
    o, ep_ret,ep_cret, ep_len = env.reset(), 0, 0, 0 
    # NOTE(review): not read anywhere in the visible part of this function.
    intest_rejected_count = 0
    
    # Clear any stale gradients on every safety-network optimizer before training.
    for i in range(n_NN): optimizer_list[i].zero_grad()
    # Epochs start at backup_epoch (presumably restored from a checkpoint — confirm).
    for epoch in range(backup_epoch, epochs):
        for t in range(local_steps_per_epoch):
            if render and proc_id()==0:
                env.render()

            # Candidate actions, both value estimates, the policy distribution
            # and its mu/std for the current observation.
            a_candidate, a_candidate_dense, v, vc, pi_acppo, mu, std = ac_ppo.stepv2(torch.as_tensor(o, dtype=torch.float32))
            # True when at least one entry of o[22:38] is >= hazard_check_margin.
            hazard_check = not all(one < hazard_check_margin for one in o[22:38]) # self.observe_hazards
            
            # o[57] is passed as the x_vel argument.
            selected_action = pof_batch_training(epoch, t+1, o, o[57], a_candidate_dense, mu, std, hazard_check)
        
            # Placeholders only; all three are overwritten after env.step below.
            reward_tmp, done_tmp, cost_tmp = 0, False, 0
            # Negative sentinel codes map to fixed fallback actions with a fixed
            # log-probability of -10; any other value indexes a_candidate.
            # NOTE(review): pof_batch_training returns None when a network
            # output contains NaN; None falls through to the else branch and
            # is used as the index into a_candidate — confirm this is intended.
            if selected_action == -20:
                action_tmp = [-1.,0.]
                logp_tmp = -10
            elif selected_action == -102:
                action_tmp = [1.,0.]
                logp_tmp = -10
            elif selected_action == -61:
                action_tmp = [0.,0.]
                logp_tmp = -10
            else:
                action_tmp = a_candidate[selected_action]
                logp_tmp = ac_ppo.pi._log_prob_from_distribution(pi_acppo, torch.Tensor(action_tmp))
            
            next_state_tmp, reward_tmp, done_tmp, info_tmp = env.step(action_tmp)
            cost_tmp = info_tmp.get('cost', 0)
            goal_checking_tmp = info_tmp.get('goal_met')

            # The transition is stored with the pre-step observation o.
            buf.store(o, action_tmp, reward_tmp, cost_tmp, v, vc, logp_tmp)
            logger.store(VVals=v)
            logger.store(CVVals=vc)

            # Binary safety label: 1 when this step incurred any cost.
            if cost_tmp > 0: safety_label_tmp = 1
            else: safety_label_tmp =  0
            # Hold exclusive locks on both windows (targeting this rank) while
            # the labelled sample is queued; win and wintwo are presumably
            # mpi4py RMA windows — confirm. win.Fence() follows the unlocks.
            win.Lock(proc_id(), lock_type=MPI.LOCK_EXCLUSIVE)
            wintwo.Lock(proc_id(), lock_type=MPI.LOCK_EXCLUSIVE)
            queue_storage_intest(action_tmp, safety_label_tmp, o.reshape(1,-1), t+1, goal_checking_tmp)
            win.Unlock(proc_id())
            wintwo.Unlock(proc_id())
            win.Fence()

            o = next_state_tmp
            ep_ret += reward_tmp
            ep_cret += cost_tmp
            ep_len += 1
            # A step with positive cost resets the environment immediately.
            # The episode accumulators are not cleared and buf.finish_path is
            # not called here, so the logged episode continues across the reset.
            if safety_label_tmp == 1: o = env.reset()

            timeout = ep_len == max_ep_len
            terminal_tmp = done_tmp or timeout
            epoch_ended = t==local_steps_per_epoch-1
            if terminal_tmp or epoch_ended:
                if epoch_ended and not(terminal_tmp) and proc_id() == 0:
                    print('Warning: trajectory cut off by epoch at %d steps.'%ep_len, flush=True)
                    
                print("proc#%d: RESET at epoch:%d, local_epoch:%d" %(proc_id(), epoch, t+1))
                # If the trajectory did not reach a terminal state, bootstrap
                # both value targets from next_state_tmp (the observation
                # returned by env.step, even if the cost-triggered reset above
                # replaced o); otherwise use zero.
                if timeout or epoch_ended:
                    next_state_tmp_to_torch = torch.as_tensor(next_state_tmp, dtype=torch.float32)
                    v,vc = ac_ppo.step(next_state_tmp_to_torch, -1)
                else:
                    v = 0
                    vc = 0
                buf.finish_path(last_val=v,last_cval=vc)
                # Episode statistics are only recorded for finished episodes,
                # not for ones cut off by the end of the epoch.
                if terminal_tmp: logger.store(EpRet=ep_ret, EpLen=ep_len, EpCost=ep_cret)
                o, ep_ret, ep_cret, ep_len = env.reset(), 0, 0, 0
                
                # Drop the n_unlabeleddata_list oldest entries from the three
                # queues; the asserts require that this empties them.
                if (n_unlabeleddata_list > 0): 
                    for _ in range(n_unlabeleddata_list):
                        queue_label_list.pop(0)
                        queue_data_list.pop(0)
                        queue_counter_list.pop(0)
                    n_unlabeleddata_list = 0
                    assert(len(queue_label_list) == 0)
                    assert(len(queue_data_list) == 0)
                    assert(len(queue_counter_list) == 0)

        # Checkpoint every save_freq epochs, written only by rank cpu-1 and
        # only when pof_checkpoint != -1. Only index 0 of the safety-network,
        # optimizer and scheduler lists is saved.
        if ((epoch+1) % save_freq == 0) and (proc_id() == cpu-1) and (pof_checkpoint != -1):
            torch.save(
                {
                    # main info section
                    "epoch": epoch,
                    "context": context,
                    # AC section
                    "ac_ppo": ac_ppo.state_dict(),
                    "pi_optimizer": pi_optimizer.state_dict(),
                    "penalty_param": "NONE",
                    "penalty_optimizer": "NONE",
                    "vf_optimizer": vf_optimizer.state_dict(),
                    "cvf_optimizer": cvf_optimizer.state_dict(),
                    # SNN section
                    "SNN_0": SNN_list[0].state_dict(),
                    "Soptim_0": optimizer_list[0].state_dict(),
                    "Ssch_0": scheduler_list[0].state_dict(), 
                    "Sschww_0": scheduler_list_with_warmup[0].state_dict(),
                    # internal test case section
                    "storage_intest_number":storage_intest_number,
                    "storage_intest_state_list": storage_intest_state_list,
                    "storage_intest_label_list": storage_intest_label_list,
                    "safe_intest_index_list": safe_intest_index_list,
                    "unsafe_intest_index_list": unsafe_intest_index_list,
                    "current_intest_state_list": current_intest_state_list,
                    "current_intest_label_list": current_intest_label_list,
                    "current_intest_size": current_intest_size,
                },
                f"/home/user/POF_data_{pof_section}/checkpoints/pof-checkpoint-{pof_checkpoint}_epoch-{epoch}.pt", pickle_protocol=4
            )
            print("    SAVE CHECKPOINT --- pof-checkpoint-%d_epoch-%d.pt" %(pof_checkpoint, epoch))

        # PPO policy / value update on the data collected this epoch.
        update()

        # NOTE(review): 'CVVals' is stored on every step above but has no
        # log_tabular entry below. If the logger only clears a key's buffer in
        # log_tabular (as Spinning Up's EpochLogger does), that buffer grows
        # for the whole run — confirm against utils.logx.
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpCost',with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch+1)*steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time()-start_time)
        logger.dump_tabular()


if __name__ == '__main__':
    import argparse

    # Command-line options as (flags, type, default), registered in this order.
    cli_options = (
        (('--env',), str, 'Safexp-PointGoal1-v0'),
        (('--hid',), int, 256),
        (('--l',), int, 2),
        (('--gamma',), float, 0.99),
        (('--seed', '-s'), int, 0),
        (('--cpu',), int, 30),
        (('--steps',), int, 4000),
        (('--epochs',), int, 50),
        (('--exp_name',), str, 'ppo_point_train_'),
        (('--mode',), str, 'ppo_point_nolag'),
        (('--checkpoint',), str, '10000'),
    )
    parser = argparse.ArgumentParser()
    for flags, value_type, default_value in cli_options:
        parser.add_argument(*flags, type=value_type, default=default_value)
    args = parser.parse_args()

    mpi_fork(args.cpu)  # run parallel code with mpi

    from utils.run_utils import setup_logger_kwargs

    # Append the last underscore-separated token of the mode to the
    # experiment name, except for the special "error" mode.
    if args.mode != "error":
        mode_suffix = args.mode.split("_")[-1]
        args.exp_name += mode_suffix
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

    # The training budget is derived from the process count: args.steps and
    # args.epochs are parsed above but not read here.
    num_steps = 1e6*args.cpu
    steps_per_epoch = 1000*args.cpu
    epochs = int(num_steps / steps_per_epoch)

    hidden_layer_sizes = [args.hid]*args.l
    ppo(
        lambda : gym.make(args.env),
        actor_critic=core.MLPActorCritic_ppo_point_train,
        ac_kwargs=dict(hidden_sizes=hidden_layer_sizes),
        gamma=args.gamma,
        seed=args.seed,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        logger_kwargs=logger_kwargs,
        agent_checkpoint_mode=args.mode,
        ppo_checkpoint=args.checkpoint,
        cpu=args.cpu,
    )
